{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Sample for Kubeflow PyTorchJob SDK"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is a sample for Kubeflow PyTorchJob SDK `kubeflow-pytorchjob`.\n",
    "\n",
    "The notebook shows how to use the Kubeflow PyTorchJob SDK to create a PyTorchJob, get it, check its status, wait for it to finish, fetch its training logs, and delete it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from kubernetes.client import V1PodTemplateSpec\n",
    "from kubernetes.client import V1ObjectMeta\n",
    "from kubernetes.client import V1PodSpec\n",
    "from kubernetes.client import V1Container\n",
    "from kubernetes.client import V1ResourceRequirements\n",
    "\n",
    "from kubeflow.pytorchjob import constants\n",
    "from kubeflow.pytorchjob import utils\n",
    "from kubeflow.pytorchjob import V1ReplicaSpec\n",
    "from kubeflow.pytorchjob import V1PyTorchJob\n",
    "from kubeflow.pytorchjob import V1PyTorchJobSpec\n",
    "from kubeflow.pytorchjob import PyTorchJobClient"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Define the namespace in which the PyTorchJob will be created. The helper below returns the namespace the SDK is currently running in when it runs inside the cluster; otherwise it falls back to the `default` namespace."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "namespace = utils.get_default_target_namespace()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define PyTorchJob"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The demo creates a PyTorchJob with one Master replica and one Worker replica to run the MNIST sample."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "container = V1Container(\n",
    "    name=\"pytorch\",\n",
    "    image=\"gcr.io/kubeflow-ci/pytorch-dist-mnist-test:v1.0\",\n",
    "    args=[\"--backend\",\"gloo\"]\n",
    ")\n",
    "\n",
    "master = V1ReplicaSpec(\n",
    "    replicas=1,\n",
    "    restart_policy=\"OnFailure\",\n",
    "    template=V1PodTemplateSpec(\n",
    "        spec=V1PodSpec(\n",
    "            containers=[container]\n",
    "        )\n",
    "    )\n",
    ")\n",
    "\n",
    "worker = V1ReplicaSpec(\n",
    "    replicas=1,\n",
    "    restart_policy=\"OnFailure\",\n",
    "    template=V1PodTemplateSpec(\n",
    "        spec=V1PodSpec(\n",
    "            containers=[container]\n",
    "        )\n",
    "    )\n",
    ")\n",
    "\n",
    "pytorchjob = V1PyTorchJob(\n",
    "    api_version=\"kubeflow.org/v1\",\n",
    "    kind=\"PyTorchJob\",\n",
    "    metadata=V1ObjectMeta(name=\"pytorch-dist-mnist-gloo\",namespace=namespace),\n",
    "    spec=V1PyTorchJobSpec(\n",
    "        clean_pod_policy=\"None\",\n",
    "        pytorch_replica_specs={\"Master\": master,\n",
    "                               \"Worker\": worker}\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create PyTorchJob"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'apiVersion': 'kubeflow.org/v1',\n",
       " 'kind': 'PyTorchJob',\n",
       " 'metadata': {'creationTimestamp': '2020-01-14T09:21:44Z',\n",
       "  'generation': 1,\n",
       "  'name': 'pytorch-dist-mnist-gloo',\n",
       "  'namespace': 'default',\n",
       "  'resourceVersion': '949015',\n",
       "  'selfLink': '/apis/kubeflow.org/v1/namespaces/default/pytorchjobs/pytorch-dist-mnist-gloo',\n",
       "  'uid': '47f9dc9a-36af-11ea-beb5-00163e01f7d2'},\n",
       " 'spec': {'cleanPodPolicy': 'None',\n",
       "  'pytorchReplicaSpecs': {'Master': {'replicas': 1,\n",
       "    'restartPolicy': 'OnFailure',\n",
       "    'template': {'spec': {'containers': [{'args': ['--backend', 'gloo'],\n",
       "        'image': 'gcr.io/kubeflow-ci/pytorch-dist-mnist-test:v1.0',\n",
       "        'name': 'pytorch'}]}}},\n",
       "   'Worker': {'replicas': 1,\n",
       "    'restartPolicy': 'OnFailure',\n",
       "    'template': {'spec': {'containers': [{'args': ['--backend', 'gloo'],\n",
       "        'image': 'gcr.io/kubeflow-ci/pytorch-dist-mnist-test:v1.0',\n",
       "        'name': 'pytorch'}]}}}}}}"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pytorchjob_client = PyTorchJobClient()\n",
    "pytorchjob_client.create(pytorchjob)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Get the created PyTorchJob "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'apiVersion': 'kubeflow.org/v1',\n",
       " 'kind': 'PyTorchJob',\n",
       " 'metadata': {'creationTimestamp': '2020-01-14T09:21:44Z',\n",
       "  'generation': 1,\n",
       "  'name': 'pytorch-dist-mnist-gloo',\n",
       "  'namespace': 'default',\n",
       "  'resourceVersion': '949030',\n",
       "  'selfLink': '/apis/kubeflow.org/v1/namespaces/default/pytorchjobs/pytorch-dist-mnist-gloo',\n",
       "  'uid': '47f9dc9a-36af-11ea-beb5-00163e01f7d2'},\n",
       " 'spec': {'cleanPodPolicy': 'None',\n",
       "  'pytorchReplicaSpecs': {'Master': {'replicas': 1,\n",
       "    'restartPolicy': 'OnFailure',\n",
       "    'template': {'spec': {'containers': [{'args': ['--backend', 'gloo'],\n",
       "        'image': 'gcr.io/kubeflow-ci/pytorch-dist-mnist-test:v1.0',\n",
       "        'name': 'pytorch'}]}}},\n",
       "   'Worker': {'replicas': 1,\n",
       "    'restartPolicy': 'OnFailure',\n",
       "    'template': {'spec': {'containers': [{'args': ['--backend', 'gloo'],\n",
       "        'image': 'gcr.io/kubeflow-ci/pytorch-dist-mnist-test:v1.0',\n",
       "        'name': 'pytorch'}]}}}}},\n",
       " 'status': {'conditions': [{'lastTransitionTime': '2020-01-14T09:21:44Z',\n",
       "    'lastUpdateTime': '2020-01-14T09:21:44Z',\n",
       "    'message': 'PyTorchJob pytorch-dist-mnist-gloo is created.',\n",
       "    'reason': 'PyTorchJobCreated',\n",
       "    'status': 'True',\n",
       "    'type': 'Created'}],\n",
       "  'replicaStatuses': {'Master': {}, 'Worker': {}},\n",
       "  'startTime': '2020-01-14T09:21:44Z'}}"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pytorchjob_client.get('pytorch-dist-mnist-gloo')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Get the PyTorchJob status, check if the PyTorchJob has been started."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Created'"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pytorchjob_client.get_job_status('pytorch-dist-mnist-gloo', namespace=namespace)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Wait for the specified PyTorchJob to finish"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "NAME                           STATE                TIME                          \n",
      "pytorch-dist-mnist-gloo        Created              2020-01-14T09:21:44Z          \n",
      "pytorch-dist-mnist-gloo        Running              2020-01-14T09:23:45Z          \n",
      "pytorch-dist-mnist-gloo        Running              2020-01-14T09:23:45Z          \n",
      "pytorch-dist-mnist-gloo        Running              2020-01-14T09:23:45Z          \n",
      "pytorch-dist-mnist-gloo        Succeeded            2020-01-14T09:28:31Z          \n"
     ]
    }
   ],
   "source": [
    "pytorchjob_client.wait_for_job('pytorch-dist-mnist-gloo', namespace=namespace, watch=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Check if the PyTorchJob succeeded"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pytorchjob_client.is_job_succeeded('pytorch-dist-mnist-gloo', namespace=namespace)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Get the PyTorchJob training logs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The logs of Pod pytorch-dist-mnist-gloo-master-0:\n",
      " Using distributed PyTorch with gloo backend\n",
      "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n",
      "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\n",
      "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n",
      "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\n",
      "Processing...\n",
      "Done!\n",
      "Train Epoch: 1 [0/60000 (0%)]\tloss=2.3000\n",
      "Train Epoch: 1 [640/60000 (1%)]\tloss=2.2135\n",
      "Train Epoch: 1 [1280/60000 (2%)]\tloss=2.1704\n",
      "Train Epoch: 1 [1920/60000 (3%)]\tloss=2.0766\n",
      "Train Epoch: 1 [2560/60000 (4%)]\tloss=1.8679\n",
      "Train Epoch: 1 [3200/60000 (5%)]\tloss=1.4135\n",
      "Train Epoch: 1 [3840/60000 (6%)]\tloss=1.0003\n",
      "Train Epoch: 1 [4480/60000 (7%)]\tloss=0.7762\n",
      "Train Epoch: 1 [5120/60000 (9%)]\tloss=0.4598\n",
      "Train Epoch: 1 [5760/60000 (10%)]\tloss=0.4860\n",
      "Train Epoch: 1 [6400/60000 (11%)]\tloss=0.4389\n",
      "Train Epoch: 1 [7040/60000 (12%)]\tloss=0.4084\n",
      "Train Epoch: 1 [7680/60000 (13%)]\tloss=0.4602\n",
      "Train Epoch: 1 [8320/60000 (14%)]\tloss=0.4289\n",
      "Train Epoch: 1 [8960/60000 (15%)]\tloss=0.3990\n",
      "Train Epoch: 1 [9600/60000 (16%)]\tloss=0.3850\n",
      "Train Epoch: 1 [10240/60000 (17%)]\tloss=0.2985\n",
      "Train Epoch: 1 [10880/60000 (18%)]\tloss=0.5031\n",
      "Train Epoch: 1 [11520/60000 (19%)]\tloss=0.5235\n",
      "Train Epoch: 1 [12160/60000 (20%)]\tloss=0.3379\n",
      "Train Epoch: 1 [12800/60000 (21%)]\tloss=0.3667\n",
      "Train Epoch: 1 [13440/60000 (22%)]\tloss=0.4503\n",
      "Train Epoch: 1 [14080/60000 (23%)]\tloss=0.3043\n",
      "Train Epoch: 1 [14720/60000 (25%)]\tloss=0.3589\n",
      "Train Epoch: 1 [15360/60000 (26%)]\tloss=0.3320\n",
      "Train Epoch: 1 [16000/60000 (27%)]\tloss=0.4406\n",
      "Train Epoch: 1 [16640/60000 (28%)]\tloss=0.3641\n",
      "Train Epoch: 1 [17280/60000 (29%)]\tloss=0.3170\n",
      "Train Epoch: 1 [17920/60000 (30%)]\tloss=0.2014\n",
      "Train Epoch: 1 [18560/60000 (31%)]\tloss=0.4985\n",
      "Train Epoch: 1 [19200/60000 (32%)]\tloss=0.3264\n",
      "Train Epoch: 1 [19840/60000 (33%)]\tloss=0.1198\n",
      "Train Epoch: 1 [20480/60000 (34%)]\tloss=0.1904\n",
      "Train Epoch: 1 [21120/60000 (35%)]\tloss=0.1424\n",
      "Train Epoch: 1 [21760/60000 (36%)]\tloss=0.3143\n",
      "Train Epoch: 1 [22400/60000 (37%)]\tloss=0.1494\n",
      "Train Epoch: 1 [23040/60000 (38%)]\tloss=0.2901\n",
      "Train Epoch: 1 [23680/60000 (39%)]\tloss=0.4670\n",
      "Train Epoch: 1 [24320/60000 (41%)]\tloss=0.2151\n",
      "Train Epoch: 1 [24960/60000 (42%)]\tloss=0.1521\n",
      "Train Epoch: 1 [25600/60000 (43%)]\tloss=0.2240\n",
      "Train Epoch: 1 [26240/60000 (44%)]\tloss=0.2629\n",
      "Train Epoch: 1 [26880/60000 (45%)]\tloss=0.2330\n",
      "Train Epoch: 1 [27520/60000 (46%)]\tloss=0.2630\n",
      "Train Epoch: 1 [28160/60000 (47%)]\tloss=0.2126\n",
      "Train Epoch: 1 [28800/60000 (48%)]\tloss=0.1327\n",
      "Train Epoch: 1 [29440/60000 (49%)]\tloss=0.2789\n",
      "Train Epoch: 1 [30080/60000 (50%)]\tloss=0.0947\n",
      "Train Epoch: 1 [30720/60000 (51%)]\tloss=0.1280\n",
      "Train Epoch: 1 [31360/60000 (52%)]\tloss=0.2458\n",
      "Train Epoch: 1 [32000/60000 (53%)]\tloss=0.3394\n",
      "Train Epoch: 1 [32640/60000 (54%)]\tloss=0.1527\n",
      "Train Epoch: 1 [33280/60000 (55%)]\tloss=0.0901\n",
      "Train Epoch: 1 [33920/60000 (57%)]\tloss=0.1451\n",
      "Train Epoch: 1 [34560/60000 (58%)]\tloss=0.1994\n",
      "Train Epoch: 1 [35200/60000 (59%)]\tloss=0.2171\n",
      "Train Epoch: 1 [35840/60000 (60%)]\tloss=0.0633\n",
      "Train Epoch: 1 [36480/60000 (61%)]\tloss=0.1369\n",
      "Train Epoch: 1 [37120/60000 (62%)]\tloss=0.1160\n",
      "Train Epoch: 1 [37760/60000 (63%)]\tloss=0.2355\n",
      "Train Epoch: 1 [38400/60000 (64%)]\tloss=0.0634\n",
      "Train Epoch: 1 [39040/60000 (65%)]\tloss=0.1062\n",
      "Train Epoch: 1 [39680/60000 (66%)]\tloss=0.1608\n",
      "Train Epoch: 1 [40320/60000 (67%)]\tloss=0.1101\n",
      "Train Epoch: 1 [40960/60000 (68%)]\tloss=0.1775\n",
      "Train Epoch: 1 [41600/60000 (69%)]\tloss=0.2285\n",
      "Train Epoch: 1 [42240/60000 (70%)]\tloss=0.0737\n",
      "Train Epoch: 1 [42880/60000 (71%)]\tloss=0.1562\n",
      "Train Epoch: 1 [43520/60000 (72%)]\tloss=0.2775\n",
      "Train Epoch: 1 [44160/60000 (74%)]\tloss=0.1418\n",
      "Train Epoch: 1 [44800/60000 (75%)]\tloss=0.1163\n",
      "Train Epoch: 1 [45440/60000 (76%)]\tloss=0.1221\n",
      "Train Epoch: 1 [46080/60000 (77%)]\tloss=0.0768\n",
      "Train Epoch: 1 [46720/60000 (78%)]\tloss=0.1950\n",
      "Train Epoch: 1 [47360/60000 (79%)]\tloss=0.0706\n",
      "Train Epoch: 1 [48000/60000 (80%)]\tloss=0.2091\n",
      "Train Epoch: 1 [48640/60000 (81%)]\tloss=0.1380\n",
      "Train Epoch: 1 [49280/60000 (82%)]\tloss=0.0950\n",
      "Train Epoch: 1 [49920/60000 (83%)]\tloss=0.1070\n",
      "Train Epoch: 1 [50560/60000 (84%)]\tloss=0.1194\n",
      "Train Epoch: 1 [51200/60000 (85%)]\tloss=0.1447\n",
      "Train Epoch: 1 [51840/60000 (86%)]\tloss=0.0662\n",
      "Train Epoch: 1 [52480/60000 (87%)]\tloss=0.0239\n",
      "Train Epoch: 1 [53120/60000 (88%)]\tloss=0.2622\n",
      "Train Epoch: 1 [53760/60000 (90%)]\tloss=0.0928\n",
      "Train Epoch: 1 [54400/60000 (91%)]\tloss=0.1297\n",
      "Train Epoch: 1 [55040/60000 (92%)]\tloss=0.1907\n",
      "Train Epoch: 1 [55680/60000 (93%)]\tloss=0.0347\n",
      "Train Epoch: 1 [56320/60000 (94%)]\tloss=0.0354\n",
      "Train Epoch: 1 [56960/60000 (95%)]\tloss=0.0770\n",
      "Train Epoch: 1 [57600/60000 (96%)]\tloss=0.1175\n",
      "Train Epoch: 1 [58240/60000 (97%)]\tloss=0.1919\n",
      "Train Epoch: 1 [58880/60000 (98%)]\tloss=0.2053\n",
      "Train Epoch: 1 [59520/60000 (99%)]\tloss=0.0639\n",
      "\n",
      "accuracy=0.9664\n",
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "pytorchjob_client.get_logs('pytorch-dist-mnist-gloo', namespace=namespace)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Delete the PyTorchJob"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'kind': 'Status',\n",
       " 'apiVersion': 'v1',\n",
       " 'metadata': {},\n",
       " 'status': 'Success',\n",
       " 'details': {'name': 'pytorch-dist-mnist-gloo',\n",
       "  'group': 'kubeflow.org',\n",
       "  'kind': 'pytorchjobs',\n",
       "  'uid': '47f9dc9a-36af-11ea-beb5-00163e01f7d2'}}"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pytorchjob_client.delete('pytorch-dist-mnist-gloo')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
